In [1]:
import os
import pickle
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
from plotly.subplots import make_subplots
import plotly.graph_objects as go

О данных, на которых делали пробинг грамматических категорий в BERT:¶

  1. en_ewt -- Genre: blog, social, reviews, email, web

  2. en_gum -- Genre: academic, blog, fiction, government, news, nonfiction, social, spoken, web, wiki

Какие категории были:¶

  • Общие для двух датасетов
  1. Case
  • Nom (The point was; He said; I said)
  • Gen (Despite Daniel' s attempts; Blood smeared the back of his hand)
  • Acc (Wash the leaves; Got it)

    То есть тут существительные и местоимения.

  1. Gender
  • Masc (Dvořák ' s first piece; He / his)
  • Fem (Kiara Perkins; She joined)
  • Neut (In other words; It is unsurprising)

    Тут существительные собственные и нарицательные, местоимения.

  1. Definite
  • Ind (When regarding an artwork; a)
  • Def (I shall leave this office; the)

    Тут артикли и указательные местоимения.

  1. Number
  • Sing (As a foreign visitor; Gonna rip your purse; Seen this before; I know it)
  • Plur (Six months; See what else we got here; Let' s just double check that)

    Тут существительные, местоимения, глаголы в каких то случаях кажется.

  1. Person
  • 1 (I / We)
  • 2 (You / Your; Just be aware)
  • 3 (всякое разное)

    Тут местоимения, глагол.

  1. Degree
  • Pos (Although useful; Some disciplines; In computer science)
  • Cmp (much smaller; more details; sooner or later; explain more)
  • Sup (On the hottest days; Most of the rooms)

    Тут прилагательные, неопределенное местоимение, наречия (?).

  1. NumForm
  • Digit (13; + всякая каша; Ø -- символы)
  • Word (one; five; + каша)
  • Roman (I, II)

    Числительные.

  1. NumType
  • Ord (first; 17th; + каша)
  • Card (1; 2014; one)
  • Mult (once; twice) -- мало их
  • Frac (2.4) -- мало их

    Числительные.

  1. Mood
  • Ind
  • Imp (Press okay)
  • Sub (только в EWT и очень мало)

    Наклонение.

  1. Tense
  • Past
  • Pres

    Кажется тут только время глагола.

  1. VerbForm
  • Fin
  • Part
  • Inf
  • Ger
  1. PronType
  • Neg
  • Dem
  • Tot
  • Rel
  • Int
  • Prs
  • Ind
  • Art
  • Различающиеся категории датасетов
  1. ExtPos (EWT)
  • ADV
  • CCONJ
  • PRON
  • SCONJ
  • ADP (figures such as the Pope (such = ADJ))

Как делали:¶

При помощи NeuroX tooklit

  1. Даем conllu файлы AIRI Probing_framework
  2. Конвертируем .csv файлы в .txt формат и делаем разные файлы для train и test.
  • ограничиваю длину предложения: 3 < len(sentence.split()) < 35
  • беру максимум 2600 предложений для train и 900 для test (эмбеддинги очень много весят) в каждой категории.
  • делаю balanced-sampling, чтобы отбирать одинаковое количество предложений для каждого лейбла в грамматической категории

! Проблема -- недостаточное количество данных для некоторых меток. Возможно, отчасти из-за ограничений на длину предложения, но также потому что их просто недостаточно. Гипотетическое решение: подумать, как вытаскивать эти данные из неиспользуемых train-данных.

  • NeuroX дает эмбеддинги для каждого предложения в train и test.
  1. Подаю данные в NeuroX:
  • linear_probe.train_logistic_regression_probe -- тренируем
  • linear_probe.evaluate_probe -- получаю accuracy для train и test. Гипотетически можно поставить F1, среднее accuracy и F1, pearson.
  • получаю ранжирование по нейронам.

Загружаем метрики и ранжирование нейронов¶

In [2]:
path = os.path.abspath('../')+'/results/'
In [3]:
with open(f'{path}scores_ewt.pkl', 'rb') as f:
    scores_ewt = pickle.load(f)
    
with open(f'{path}scores_c_ewt.pkl', 'rb') as f:
    scores_c_ewt = pickle.load(f)
    
with open(f'{path}neurons_ewt.pkl', 'rb') as f:
    ordered_neurons_ewt = pickle.load(f)

with open(f'{path}neurons_gum.pkl', 'rb') as f:
    ordered_neurons_gum = pickle.load(f)
    
with open(f'{path}size_ewt.pkl', 'rb') as f:
    size_ewt = pickle.load(f)
    
with open(f'{path}scores_gum.pkl', 'rb') as f:
    scores_gum = pickle.load(f)
    
with open(f'{path}scores_c_gum.pkl', 'rb') as f:
    scores_c_gum = pickle.load(f)
    
with open(f'{path}size_gum.pkl', 'rb') as f:
    size_gum = pickle.load(f)

with open(f'{path}scores_keep_ewt.pkl', 'rb') as f:
    scores_keep_ewt = pickle.load(f)
    
with open(f'{path}scores_keep_gum.pkl', 'rb') as f:
    scores_keep_gum = pickle.load(f)
    
with open(f'{path}scores_remove_ewt.pkl', 'rb') as f:
    scores_remove_ewt = pickle.load(f)
    
with open(f'{path}scores_remove_gum.pkl', 'rb') as f:
    scores_remove_gum = pickle.load(f)
In [4]:
def bad_scores(scores):
    
    for k, v in scores.items():
        m1 = []
        m2 = []

        for i, j in v[0].items():

            if v[0][i] < 0.5:

                if not m1.__contains__(v[0]):
                    m1.append(v[0])

                print(k, i, 'train_score', v[0][i])
                
        if m1:
            print(k, 'train', m1)
            print('---------------------')

        for i, j in v[1].items():
            
            if v[1][i] < 0.5:

                if not m2.__contains__(v[1]):
                    m2.append(v[1])

                print(k, i, 'test_score', v[1][i])
                
        if m2:
            print(k, 'test', m2)
            print('---------------------')

Плохая классификация для датасета en_ewt (accuracy < 0.5)¶

Genre: blog, social, reviews, email, web

In [5]:
bad_scores(scores_ewt)
Definite __OVERALL__ test_score 0.2611111111111111
Definite Ind test_score 0.2866666666666667
Definite Def test_score 0.23555555555555555
Definite test [{'__OVERALL__': 0.2611111111111111, 'Ind': 0.2866666666666667, 'Def': 0.23555555555555555}]
---------------------
NumType Ord test_score 0.3
NumType Frac test_score 0.0
NumType Mult test_score 0.2222222222222222
NumType test [{'__OVERALL__': 0.7949640287769785, 'Ord': 0.3, 'Card': 0.92, 'Frac': 0.0, 'Mult': 0.2222222222222222}]
---------------------
PronType Rel test_score 0.4642857142857143
PronType test [{'__OVERALL__': 0.7123552123552124, 'Int': 0.9017857142857143, 'Prs': 0.8125, 'Neg': 0.8333333333333334, 'Tot': 0.6, 'Dem': 0.6607142857142857, 'Art': 0.5714285714285714, 'Rel': 0.4642857142857143, 'Ind': 0.5806451612903226}]
---------------------
NumForm Roman train_score 0.4
NumForm train [{'__OVERALL__': 0.9446415897799858, 'Roman': 0.4, 'Digit': 1.0, 'Word': 0.8686514886164624}]
---------------------
NumForm Roman test_score 0.0
NumForm test [{'__OVERALL__': 0.8173302107728337, 'Roman': 0.0, 'Digit': 0.9266666666666666, 'Word': 0.5725806451612904}]
---------------------
Mood Sub test_score 0.0
Mood test [{'__OVERALL__': 0.9287128712871288, 'Imp': 0.9504950495049505, 'Sub': 0.0, 'Ind': 0.9233333333333333}]
---------------------
ExtPos ADP test_score 0.2222222222222222
ExtPos CCONJ test_score 0.25
ExtPos PRON test_score 0.0
ExtPos SCONJ test_score 0.0
ExtPos test [{'__OVERALL__': 0.6285714285714286, 'ADP': 0.2222222222222222, 'ADV': 0.95, 'CCONJ': 0.25, 'PRON': 0.0, 'SCONJ': 0.0}]
---------------------
In [6]:
def accuracy_plot(dct_acc, dct_data):
    cats=[k for k in dct_acc.keys()]
    assert [k for k in dct_acc.keys()] == [k for k in dct_data.keys()]
    accuracy_train=[round(v[0]['__OVERALL__'], 2) for k, v in dct_acc.items()]
    accuracy_test = [round(v[1]['__OVERALL__'], 2) for k, v in dct_acc.items()]
    train=[v[0] for k, v in dct_data.items()]
    test = [v[1] for k, v in dct_data.items()]
    num_classes = [v[2] for k, v in dct_data.items()]
    d = pd.DataFrame({'categories': cats, 'train_acc' : accuracy_train, 'test_acc': accuracy_test,
                      'train_data': train, 'test_data': test, 'num_classes': num_classes})    
    fig1 = px.line(d, x='categories', y=['train_acc', 'test_acc'], template="plotly_white") 
    fig2 = px.line(d, x='categories', y=['train_data', 'test_data', 'num_classes'], template= "seaborn") 
    fig3 = px.bar(d, x='categories', y='num_classes', template="plotly") 
    fig3.update_traces(texttemplate='%{y}', textposition='outside')
    fig3.update_layout(uniformtext_minsize=8, uniformtext_mode='hide')
    fig = make_subplots(rows=3, cols=1, subplot_titles=('Accuracy', 'Data size', 'Num classes'))
    fig.add_trace(fig1['data'][0], row=1, col=1)
    fig.add_trace(fig1['data'][1], row=1, col=1)
    fig.add_trace(fig2['data'][0], row=2, col=1)
    fig.add_trace(fig2['data'][1], row=2, col=1)
    fig.add_trace(fig3['data'][0], row=3, col=1)
    fig.update_layout(height=1400, width=900)
    fig.update_layout(showlegend=False)
    
    fig.show()
In [7]:
accuracy_plot(scores_ewt, size_ewt)

EWT Control Task¶

In [8]:
def accuracy(dct_acc):
    cats=[k for k in dct_acc.keys()]
    accuracy_train=[round(v[0]['__OVERALL__'], 2) for k, v in dct_acc.items()]
    accuracy_test = [round(v[1]['__OVERALL__'], 2) for k, v in dct_acc.items()]
    d = pd.DataFrame({'categories': cats, 'train_acc' : accuracy_train, 'test_acc': accuracy_test})    
    fig = px.line(d, x='categories', y=['train_acc', 'test_acc'], template="plotly_white") 
    fig.show()
In [9]:
accuracy(scores_c_ewt)

Плохая классификация для датасета en_gum¶

Genre: academic, blog, fiction, government, news, nonfiction, social, spoken, web, wiki

In [10]:
bad_scores(scores_gum)
PronType __OVERALL__ test_score 0.3849129593810445
PronType Rel test_score 0.027777777777777776
PronType Ind test_score 0.047619047619047616
PronType Dem test_score 0.16071428571428573
PronType Neg test_score 0.0
PronType Art test_score 0.0
PronType test [{'__OVERALL__': 0.3849129593810445, 'Prs': 0.7678571428571429, 'Rel': 0.027777777777777776, 'Ind': 0.047619047619047616, 'Int': 0.875, 'Dem': 0.16071428571428573, 'Neg': 0.0, 'Art': 0.0, 'Tot': 0.6122448979591837}]
---------------------
VerbForm __OVERALL__ test_score 0.20310296191819463
VerbForm Fin test_score 0.09777777777777778
VerbForm Inf test_score 0.23555555555555555
VerbForm Part test_score 0.24
VerbForm Ger test_score 0.4411764705882353
VerbForm test [{'__OVERALL__': 0.20310296191819463, 'Fin': 0.09777777777777778, 'Inf': 0.23555555555555555, 'Part': 0.24, 'Ger': 0.4411764705882353}]
---------------------
Degree Cmp test_score 0.3111111111111111
Degree test [{'__OVERALL__': 0.5471204188481675, 'Cmp': 0.3111111111111111, 'Pos': 0.5733333333333334, 'Sup': 0.6216216216216216}]
---------------------
NumType __OVERALL__ test_score 0.4523076923076923
NumType Frac test_score 0.20833333333333334
NumType Ord test_score 0.0
NumType Mult test_score 0.2857142857142857
NumType test [{'__OVERALL__': 0.4523076923076923, 'Frac': 0.20833333333333334, 'Card': 0.6222222222222222, 'Ord': 0.0, 'Mult': 0.2857142857142857}]
---------------------
NumForm Roman test_score 0.0
NumForm test [{'__OVERALL__': 0.7682619647355163, 'Roman': 0.0, 'Word': 0.7134146341463414, 'Digit': 0.8209606986899564}]
---------------------
In [11]:
accuracy_plot(scores_gum, size_gum)

GUM Control Task¶

In [12]:
accuracy(scores_c_gum)

Посмотреть на N нейронов¶

В чем суть: пока простой тупой метод посмотреть общий set(нейронов), если для каждой категории выбираем N-top нейронов в ранжировании.

Всего нейронов в BERT 9984 (13 слоев * 768 -- размерность эмбеддинга).

Всего в каждом датасете было 12 категорий.

In [13]:
def get_overall_common_neurons(dct, nn=[50, 100, 150, 200, 250, 300, 350, 400, 450, 500]):
    d1, d2 = [], []
    for n in nn:
        neurons_list = []
        for k, v in dct.items():
            v = v[:n]
            neurons_list+=v
        
        d1.append(len(neurons_list))
        d2.append(len(set(neurons_list)))
        
    d = pd.DataFrame({f'top N-neuron for all {len(dct.keys())} categories': nn, 'all' : d1, 'unique': d2})    
    fig = px.bar(d, x=f'top N-neuron for all {len(dct.keys())} categories', y=['unique', 'all'], template="plotly_white", barmode='group') 
    fig.update_traces(texttemplate='%{y}', textposition='outside')
    fig.update_layout(uniformtext_minsize=8, uniformtext_mode='hide')
    fig.update_xaxes(tick0=0, dtick=50)
    fig.show()

EWT¶

In [14]:
get_overall_common_neurons(ordered_neurons_ewt) 

GUM¶

In [15]:
get_overall_common_neurons(ordered_neurons_gum) 

Посмотрим на пересекающиеся нейроны в двух датасетах¶

  • Отдельно по каждой категории
In [16]:
def common_neurons(d1, d2, nn=[50, 100, 150, 200, 250, 300, 350, 400, 450, 500]):
    common_cats = []
    for k1 in d1.keys():
        for k2 in d2.keys():   
            if k1 == k2:
                common_cats.append(k1)
            
    df = pd.DataFrame(index=nn, columns=common_cats)
    df = df.fillna(0)
    
    for cat in common_cats:
        common_neurons = []
        for n in nn:
            p = set(d1[cat][:n]) & set(d2[cat][:n])
            common_neurons.append(len(p))
        df[cat] = common_neurons
    return df        
In [17]:
df = common_neurons(ordered_neurons_ewt, ordered_neurons_gum)
df
Out[17]:
Degree Definite NumType PronType NumForm VerbForm Number Person Case Mood Tense Gender
50 1 3 0 2 1 1 3 5 2 5 5 13
100 4 7 3 2 1 4 8 11 4 10 11 21
150 5 11 4 4 2 7 17 16 10 19 15 30
200 8 17 11 7 4 13 18 24 14 25 23 39
250 11 23 15 11 9 26 24 34 16 36 33 47
300 16 38 23 18 16 31 35 41 19 46 45 66
350 26 47 30 22 21 42 39 50 23 57 55 84
400 36 61 35 29 28 51 46 59 29 71 65 96
450 43 66 43 35 35 64 56 74 37 79 73 113
500 54 78 49 45 41 79 67 81 45 94 87 127
In [18]:
def visualise(df):
    
    CAT = list(df.columns)

    a = 4  # number of rows
    b = 3  # number of columns
    c = 1  # initialize plot counter

    fig = plt.figure(figsize=(10,8))
    col_map = plt.get_cmap('Paired')
    for i in CAT:
        plt.subplot(a, b, c)
        plt.title(f'{i}')
        plt.xlabel('N neurons') 
        plt.bar(list(df.index), list(df[i]), width=30, bottom=0, 
                color=col_map.colors, edgecolor='k', linewidth=1)
        c = c + 1

    plt.tight_layout()
    plt.show()
In [19]:
visualise(df)

Ablation¶

Keep Top 100¶

In [20]:
accuracy(scores_keep_ewt)

Keep Bottom 100¶

In [21]:
accuracy(scores_remove_ewt)

Keep Top 100¶

In [22]:
accuracy(scores_keep_gum)

Keep Bottom 100¶

In [23]:
accuracy(scores_remove_gum)